package com.chachae.webrtc.netty.socket.standard;

import com.chachae.webrtc.netty.socket.annotation.ServerEndpoint;
import com.chachae.webrtc.netty.socket.exception.DeploymentException;
import com.chachae.webrtc.netty.socket.pojo.PojoEndpointServer;
import com.chachae.webrtc.netty.socket.pojo.PojoMethodMapping;
import java.net.InetSocketAddress;
import java.util.HashMap;
import java.util.LinkedHashSet;
import java.util.Map;
import java.util.Set;
import java.util.StringJoiner;
import org.springframework.beans.TypeConverter;
import org.springframework.beans.TypeMismatchException;
import org.springframework.beans.factory.BeanFactory;
import org.springframework.beans.factory.BeanFactoryAware;
import org.springframework.beans.factory.SmartInitializingSingleton;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.config.BeanExpressionContext;
import org.springframework.beans.factory.config.BeanExpressionResolver;
import org.springframework.beans.factory.support.AbstractBeanFactory;
import org.springframework.context.ApplicationContext;
import org.springframework.context.support.ApplicationObjectSupport;
import org.springframework.core.annotation.AnnotatedElementUtils;
import org.springframework.core.env.Environment;

/**
 * Registers every bean annotated with {@link ServerEndpoint} found in the application context and
 * starts one {@link WebsocketServer} per distinct host/port pair.
 *
 * <p>Registration runs from {@link #afterSingletonsInstantiated()}. Endpoints whose annotation
 * resolves to the same {@link InetSocketAddress} share a single server and are added to it as
 * additional path mappings.
 *
 * <p>String-valued annotation attributes are passed through the bean factory's embedded value
 * resolver and, when available, its bean expression resolver before being converted to the
 * required type.
 *
 * @author Yeauty
 * @version 1.0
 */
public class ServerEndpointExporter extends ApplicationObjectSupport implements SmartInitializingSingleton, BeanFactoryAware {

  // NOTE(review): not read anywhere in this class; kept because it is package-visible and may be
  // accessed from elsewhere in the package.
  @Autowired
  Environment environment;

  /** Bean factory used for placeholder/expression resolution and type conversion. */
  private AbstractBeanFactory beanFactory;

  /** One server per distinct address; populated by {@link #registerEndpoint(Class)}. */
  private final Map<InetSocketAddress, WebsocketServer> addressWebsocketServerMap = new HashMap<>();

  /** Triggers endpoint registration once all singletons have been instantiated. */
  @Override
  public void afterSingletonsInstantiated() {
    registerEndpoints();
  }

  /**
   * Stores the owning bean factory.
   *
   * @param beanFactory the owning factory; must be an {@link AbstractBeanFactory}
   * @throws IllegalArgumentException if {@code beanFactory} is not an {@link AbstractBeanFactory}
   */
  @Override
  public void setBeanFactory(BeanFactory beanFactory) {
    if (!(beanFactory instanceof AbstractBeanFactory)) {
      throw new IllegalArgumentException(
          "ServerEndpointExporter requires an AbstractBeanFactory: " + beanFactory);
    }
    this.beanFactory = (AbstractBeanFactory) beanFactory;
  }

  /**
   * Collects the types of all beans annotated with {@link ServerEndpoint}, registers each distinct
   * type once, then initializes every server that was created.
   *
   * @throws IllegalStateException if the type of an annotated bean cannot be determined or an
   *     endpoint fails to register
   */
  protected void registerEndpoints() {
    // A set so that several beans of the same class register the endpoint only once.
    Set<Class<?>> endpointClasses = new LinkedHashSet<>();

    ApplicationContext context = getApplicationContext();
    if (context != null) {
      for (String beanName : context.getBeanNamesForAnnotation(ServerEndpoint.class)) {
        Class<?> endpointClass = context.getType(beanName);
        if (endpointClass == null) {
          // getType may return null; fail with the bean name instead of an opaque error later.
          throw new IllegalStateException(
              "Cannot determine type of @ServerEndpoint bean '" + beanName + "'");
        }
        endpointClasses.add(endpointClass);
      }
    }

    for (Class<?> endpointClass : endpointClasses) {
      registerEndpoint(endpointClass);
    }

    init();
  }

  /**
   * Initializes every registered server and logs its port and path patterns.
   *
   * <p>If the thread is interrupted during initialization, the failure is logged, the interrupt
   * status is restored and no further servers are started.
   */
  private void init() {
    for (Map.Entry<InetSocketAddress, WebsocketServer> entry : addressWebsocketServerMap.entrySet()) {
      WebsocketServer websocketServer = entry.getValue();
      try {
        websocketServer.init();
        PojoEndpointServer pojoEndpointServer = websocketServer.getPojoEndpointServer();
        StringJoiner stringJoiner = new StringJoiner(",");
        pojoEndpointServer.getPathMatcherSet().forEach(pathMatcher -> stringJoiner.add("'" + pathMatcher.getPattern() + "'"));
        logger.info(String.format("\033[34mNetty WebSocket started on port: %s with context path(s): %s .\033[0m", pojoEndpointServer.getPort(), stringJoiner));
      } catch (InterruptedException e) {
        logger.error(String.format("websocket [%s] init fail", entry.getKey()), e);
        // Restore the flag so callers up the stack can observe the interruption.
        Thread.currentThread().interrupt();
        return;
      }
    }
  }

  /**
   * Builds the configuration and method mapping for one endpoint class and attaches it to the
   * server for its address, creating that server if none exists yet.
   *
   * @param endpointClass class carrying (directly or via meta-annotation) {@link ServerEndpoint}
   * @throws IllegalStateException if the annotation is missing or the method mapping cannot be
   *     built
   */
  private void registerEndpoint(Class<?> endpointClass) {
    ServerEndpoint annotation = AnnotatedElementUtils.findMergedAnnotation(endpointClass, ServerEndpoint.class);
    if (annotation == null) {
      throw new IllegalStateException("missingAnnotation ServerEndpoint");
    }
    ServerEndpointConfig serverEndpointConfig = buildConfig(annotation);

    ApplicationContext context = getApplicationContext();
    PojoMethodMapping pojoMethodMapping;
    try {
      pojoMethodMapping = new PojoMethodMapping(endpointClass, context, beanFactory);
    } catch (DeploymentException e) {
      throw new IllegalStateException("Failed to register ServerEndpointConfig: " + serverEndpointConfig, e);
    }

    InetSocketAddress inetSocketAddress = new InetSocketAddress(serverEndpointConfig.getHost(), serverEndpointConfig.getPort());
    String path = resolveAnnotationValue(annotation.value(), String.class, "path");

    WebsocketServer websocketServer = addressWebsocketServerMap.get(inetSocketAddress);
    if (websocketServer == null) {
      // First endpoint for this address: its config becomes the server's config.
      PojoEndpointServer pojoEndpointServer = new PojoEndpointServer(pojoMethodMapping, serverEndpointConfig, path);
      websocketServer = new WebsocketServer(pojoEndpointServer, serverEndpointConfig);
      addressWebsocketServerMap.put(inetSocketAddress, websocketServer);
    } else {
      // Address already served: only the path mapping is added; this endpoint's config is not used.
      websocketServer.getPojoEndpointServer().addPathPojoMethodMapping(path, pojoMethodMapping);
    }
  }

  /**
   * Resolves every attribute of the annotation and packs the results into a
   * {@link ServerEndpointConfig}.
   *
   * @param annotation the endpoint annotation to read
   * @return the resolved configuration
   * @throws IllegalArgumentException if an attribute cannot be converted to its required type
   */
  private ServerEndpointConfig buildConfig(ServerEndpoint annotation) {
    String host = resolveAnnotationValue(annotation.host(), String.class, "host");
    int port = resolveAnnotationValue(annotation.port(), Integer.class, "port");
    String path = resolveAnnotationValue(annotation.value(), String.class, "value");
    int bossLoopGroupThreads = resolveAnnotationValue(annotation.bossLoopGroupThreads(), Integer.class, "bossLoopGroupThreads");
    int workerLoopGroupThreads = resolveAnnotationValue(annotation.workerLoopGroupThreads(), Integer.class, "workerLoopGroupThreads");
    boolean useCompressionHandler = resolveAnnotationValue(annotation.useCompressionHandler(), Boolean.class, "useCompressionHandler");

    int optionConnectTimeoutMillis = resolveAnnotationValue(annotation.optionConnectTimeoutMillis(), Integer.class, "optionConnectTimeoutMillis");
    int optionSoBacklog = resolveAnnotationValue(annotation.optionSoBacklog(), Integer.class, "optionSoBacklog");

    int childOptionWriteSpinCount = resolveAnnotationValue(annotation.childOptionWriteSpinCount(), Integer.class, "childOptionWriteSpinCount");
    int childOptionWriteBufferHighWaterMark = resolveAnnotationValue(annotation.childOptionWriteBufferHighWaterMark(), Integer.class, "childOptionWriteBufferHighWaterMark");
    int childOptionWriteBufferLowWaterMark = resolveAnnotationValue(annotation.childOptionWriteBufferLowWaterMark(), Integer.class, "childOptionWriteBufferLowWaterMark");
    int childOptionSoRcvbuf = resolveAnnotationValue(annotation.childOptionSoRcvbuf(), Integer.class, "childOptionSoRcvbuf");
    int childOptionSoSndbuf = resolveAnnotationValue(annotation.childOptionSoSndbuf(), Integer.class, "childOptionSoSndbuf");
    boolean childOptionTcpNodelay = resolveAnnotationValue(annotation.childOptionTcpNodelay(), Boolean.class, "childOptionTcpNodelay");
    boolean childOptionSoKeepalive = resolveAnnotationValue(annotation.childOptionSoKeepalive(), Boolean.class, "childOptionSoKeepalive");
    int childOptionSoLinger = resolveAnnotationValue(annotation.childOptionSoLinger(), Integer.class, "childOptionSoLinger");
    boolean childOptionAllowHalfClosure = resolveAnnotationValue(annotation.childOptionAllowHalfClosure(), Boolean.class, "childOptionAllowHalfClosure");

    int readerIdleTimeSeconds = resolveAnnotationValue(annotation.readerIdleTimeSeconds(), Integer.class, "readerIdleTimeSeconds");
    int writerIdleTimeSeconds = resolveAnnotationValue(annotation.writerIdleTimeSeconds(), Integer.class, "writerIdleTimeSeconds");
    int allIdleTimeSeconds = resolveAnnotationValue(annotation.allIdleTimeSeconds(), Integer.class, "allIdleTimeSeconds");

    int maxFramePayloadLength = resolveAnnotationValue(annotation.maxFramePayloadLength(), Integer.class, "maxFramePayloadLength");

    return new ServerEndpointConfig(host, port, path, bossLoopGroupThreads, workerLoopGroupThreads, useCompressionHandler, optionConnectTimeoutMillis, optionSoBacklog, childOptionWriteSpinCount,
        childOptionWriteBufferHighWaterMark, childOptionWriteBufferLowWaterMark, childOptionSoRcvbuf, childOptionSoSndbuf, childOptionTcpNodelay, childOptionSoKeepalive, childOptionSoLinger,
        childOptionAllowHalfClosure, readerIdleTimeSeconds, writerIdleTimeSeconds, allIdleTimeSeconds, maxFramePayloadLength);
  }

  /**
   * Resolves a raw annotation attribute to the requested type.
   *
   * <p>String values first have embedded values resolved by the bean factory and are then
   * evaluated by its bean expression resolver, if one is configured. The result (or the untouched
   * non-string value) is converted with the bean factory's {@link TypeConverter}.
   *
   * @param value raw attribute value; {@code null} yields {@code null}
   * @param requiredType target type
   * @param paramName attribute name, used only in the error message
   * @param <T> target type
   * @return the converted value, or {@code null} if {@code value} is {@code null}
   * @throws IllegalArgumentException if conversion fails; the original
   *     {@link TypeMismatchException} is attached as the cause
   */
  private <T> T resolveAnnotationValue(Object value, Class<T> requiredType, String paramName) {
    if (value == null) {
      return null;
    }
    TypeConverter typeConverter = beanFactory.getTypeConverter();

    if (value instanceof String) {
      String strVal = beanFactory.resolveEmbeddedValue((String) value);
      BeanExpressionResolver beanExpressionResolver = beanFactory.getBeanExpressionResolver();
      if (beanExpressionResolver != null) {
        value = beanExpressionResolver.evaluate(strVal, new BeanExpressionContext(beanFactory, null));
      } else {
        value = strVal;
      }
    }
    try {
      return typeConverter.convertIfNecessary(value, requiredType);
    } catch (TypeMismatchException e) {
      // Keep the cause so the offending value and conversion details are not lost.
      throw new IllegalArgumentException(
          "Failed to convert value of parameter '" + paramName + "' to required type '" + requiredType.getName() + "'", e);
    }
  }

}
